/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "transformer_utils.h"

// C/C++ standard library
#include <cerrno>
#include <cstdlib>
#include <memory>

// api/framework
#include "graph/op/const_defs.h"

// inc/framework
#include "infra/base/assertion.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_compatibler.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/debug/ge_graph_attr_define.h"

// src/framework
#include "base/common/tensor_util/trans_tensor.h"

namespace hiai {
// Conversion functions
/**
 * Converts the string attribute "data_format" of `desc` into the integer attribute "format"
 * ("NCHW" -> 0, "NHWC" -> 1).
 * @param desc Op description updated in place.
 * @return GRAPH_SUCCESS when "data_format" is absent or was converted; an error (raised through
 *         HIAI_EXPECT_TRUE) when "data_format" holds a value not listed in the table.
 */
ge::GraphErrCodeStatus DataFormatToEnum(ge::OpDesc& desc)
{
    // Built once and shared by all calls instead of being reconstructed on every invocation.
    static const std::map<std::string, int64_t> kDataFormatToEnum = {
        {"NCHW", 0},
        {"NHWC", 1},
    };
    std::string dataFormat;
    if (!ge::AttrUtils::GetStr(desc, "data_format", dataFormat)) {
        return ge::GRAPH_SUCCESS;
    }
    // Single lookup instead of count() followed by at().
    const auto it = kDataFormatToEnum.find(dataFormat);
    HIAI_EXPECT_TRUE(it != kDataFormatToEnum.end());
    (void)ge::AttrUtils::SetInt(desc, "format", it->second);
    return ge::GRAPH_SUCCESS;
}

/**
 * Rewrites the op type of `node` to `config.dstOpType`.
 * Nothing happens when the requested direction (`isOldToNew`) differs from the configured one,
 * or when no destination type is configured.
 * @return Always GRAPH_SUCCESS.
 */
ge::GraphErrCodeStatus TransformTypeConverter(ge::Node& node, const ConvertConfigInfo& config, bool isOldToNew)
{
    const bool directionMatches = (config.isOldToNew == isOldToNew);
    if (directionMatches) {
        ge::OpDesc& opDesc = node.ROLE(NodeSpec).OpDesc();
        const bool hasTargetType = (config.dstOpType != "");
        if (hasTargetType) {
            opDesc.SetType(config.dstOpType);
        }
    }
    return ge::GRAPH_SUCCESS;
}

/**
 * Runs the op-type rewrite followed by the attribute renaming for `node`.
 * HIAI_EXPECT_EXEC stops at the first step that does not succeed.
 * @return GRAPH_SUCCESS when both steps succeed.
 */
ge::GraphErrCodeStatus TypeAndAttrConverter(ge::Node& node, const ConvertConfigInfo& config, bool isOldToNew)
{
    const ge::GraphErrCodeStatus typeStatus = TransformTypeConverter(node, config, isOldToNew);
    HIAI_EXPECT_EXEC(typeStatus);

    const ge::GraphErrCodeStatus attrStatus = UpdateAttrConverter(node, config, isOldToNew);
    HIAI_EXPECT_EXEC(attrStatus);

    return ge::GRAPH_SUCCESS;
}

/**
 * For the new-IR -> old-IR direction, writes the configured default for the first attribute
 * (attrInfos[0].srcName, value attrInfos[0].srcDefValue) as an integer attribute.
 * This happens only when a non-empty default is configured and the op does not already carry
 * the attribute. The old-IR -> new-IR direction is a no-op.
 * @return GRAPH_SUCCESS on success; an error (via HIAI_EXPECT_TRUE) when attrInfos is empty, or
 *         when the default contains no parsable number or lies outside the 64-bit integer range.
 */
ge::GraphErrCodeStatus SetIntAttrDefValueConverter(ge::Node& node, const ConvertConfigInfo& config, bool isOldToNew)
{
    if (isOldToNew) {
        return ge::GRAPH_SUCCESS;
    }
    // new ir -> old ir, set attr default value
    ge::OpDesc& desc = node.ROLE(NodeSpec).OpDesc();
    HIAI_EXPECT_TRUE(config.attrInfos.size() > 0);
    const auto& attrInfo = config.attrInfos[0];
    if (attrInfo.srcDefValue.empty() || desc.HasAttr(attrInfo.srcName)) {
        return ge::GRAPH_SUCCESS;
    }

    // Parse with strtoll rather than stoi, for two reasons:
    //  - the target is int64_t, but stoi rejects anything outside int's range;
    //  - stoi reports bad input by throwing, which would escape this status-returning function.
    // Like stoi, leading whitespace and characters after the number are tolerated.
    const char* text = attrInfo.srcDefValue.c_str();
    char* end = nullptr;
    errno = 0;
    const long long parsed = std::strtoll(text, &end, 10);
    HIAI_EXPECT_TRUE(end != text && errno != ERANGE);

    (void)ge::AttrUtils::SetInt(desc, attrInfo.srcName, static_cast<int64_t>(parsed));
    return ge::GRAPH_SUCCESS;
}

/**
 * Renames attributes of `node` according to `config.attrInfos`.
 * Old -> new moves each attribute from srcName to dstName; new -> old moves it from dstName back
 * to srcName. Entries with an empty name, or whose source attribute is not present, are skipped.
 * @return GRAPH_SUCCESS, or the failure reported through HIAI_EXPECT_EXEC when DelAttr fails.
 */
ge::GraphErrCodeStatus UpdateAttrConverter(ge::Node& node, const ConvertConfigInfo& config, bool isOldToNew)
{
    ge::OpDesc& desc = node.ROLE(NodeSpec).OpDesc();
    for (const auto& attrInfo : config.attrInfos) {
        const string fromName = isOldToNew ? attrInfo.srcName : attrInfo.dstName;
        const string toName = isOldToNew ? attrInfo.dstName : attrInfo.srcName;
        if (fromName == "" || toName == "") {
            continue;
        }
        AttrValue movedValue;
        if (desc.GetAttr(fromName, movedValue) != ge::GRAPH_SUCCESS) {
            continue;
        }
        // Remove the old key first, then store the same value under the new key.
        HIAI_EXPECT_EXEC(desc.DelAttr(fromName));
        desc.SetAttr(toName, movedValue);
    }
    return ge::GRAPH_SUCCESS;
}

/**
 * In the old-IR -> new-IR direction, removes the idle input endpoints of `node` via
 * NodeCompatibler::RemoveIdleEndpoint. In the other direction, or when the node has no idle
 * inputs, nothing is changed.
 * @return GRAPH_SUCCESS when nothing needs removing; otherwise the result of RemoveIdleEndpoint.
 */
ge::GraphErrCodeStatus EmptyInputNodeRemoveForOMConverter(
    ge::Node& node, const ConvertConfigInfo& config, bool isOldToNew)
{
    (void)config;
    // Short-circuit keeps IdleInputEndpoints() from being queried in the new -> old direction.
    const bool needsRemoval = isOldToNew && (node.ROLE(NodeSpec).IdleInputEndpoints().size() != 0);
    if (!needsRemoval) {
        return ge::GRAPH_SUCCESS;
    }
    return node.ROLE(NodeCompatibler).RemoveIdleEndpoint();
}

/**
 * Converts every FLOAT16 weight tensor of `node` to FLOAT, in place.
 * Weights of other data types are untouched. FLOAT16 weights with an empty data buffer are
 * skipped, and processing continues with the remaining weights.
 * @return GRAPH_SUCCESS on success; GRAPH_FAILED if the fp16 -> fp32 conversion fails; or an
 *         error (via the HIAI_EXPECT_* macros) if a weight is null, the converted size does not
 *         fit in 32 bits, allocation fails, or storing the converted data fails.
 */
ge::GraphErrCodeStatus TransWeightsHALFToFloat(ge::Node& node)
{
    const auto weightsVec = ge::OpDescUtils::MutableWeights(node);
    for (const auto& weight : weightsVec) {
        HIAI_EXPECT_NOT_NULL(weight);

        ge::TensorDesc& weightTensorDesc = weight->MutableTensorDesc();
        if (weightTensorDesc.GetDataType() != ge::DT_FLOAT16) {
            continue;
        }

        const auto srcSize = weight->GetData().GetSize();
        if (srcSize == 0) {
            // Nothing to convert for this weight. Skip it rather than returning, so later weights
            // are still converted.
            continue;
        }
        // fp32 data is twice the size of fp16 data; make sure that size fits in uint32_t.
        HIAI_EXPECT_TRUE(srcSize <= UINT32_MAX / 2);
        const uint32_t outputSize = static_cast<uint32_t>(srcSize) * 2;

        // Value-initialize so no descriptor field is left indeterminate; only dataSize is set here.
        ge::ccTensor_t srcWeightTensor{};
        ge::ccTensor_t dstWeightTensor{};
        srcWeightTensor.dataSize = static_cast<uint32_t>(srcSize);
        dstWeightTensor.dataSize = outputSize;

        // Owned buffer: released automatically on every exit path, including early error returns.
        std::unique_ptr<char[]> output(new (std::nothrow) char[outputSize]);
        HIAI_EXPECT_NOT_NULL(output.get());
        hiai::Status ret =
            TransTensorHALFToFloat(srcWeightTensor, weight->GetData().GetData(), dstWeightTensor, output.get());
        if (ret != hiai::SUCCESS) {
            FMK_LOGE("trans weight from fp16 to fp32 fail.");
            return ge::GRAPH_FAILED;
        }
        // Only mark the tensor as FLOAT once the converted data has actually been stored.
        HIAI_EXPECT_EXEC(weight->SetData(reinterpret_cast<uint8_t*>(output.get()), outputSize));
        weightTensorDesc.SetDataType(ge::DT_FLOAT);
    }
    return ge::GRAPH_SUCCESS;
}

/**
 * Restores the data of each Const node's "value" tensor from the graph's merged weight buffer.
 *
 * The buffer's base address and size come from the graph attributes SRC_MERGED_WEIGHT_ADDR and
 * SRC_MERGED_WEIGHT_SIZE; both default to 0 when absent. For every Const whose tensor data is
 * still empty:
 *  - its "merged_offset" tensor attribute and recorded weight size select a slice of that buffer;
 *  - the slice is bounds-checked and set as the tensor data;
 *  - "merged_offset" is deleted and the weight size is reset to 0.
 * Finally, WEIGHT_MERGED, SRC_MERGED_WEIGHT_SIZE and SRC_MERGED_WEIGHT_ADDR are deleted from the graph.
 *
 * @return GRAPH_SUCCESS, or the failure reported through the HIAI_EXPECT_* macros.
 */
ge::GraphErrCodeStatus SplitGraphMergedWeight(ge::ComputeGraph& graph)
{
    int64_t mergedSize = 0;
    int64_t mergedBase = 0;
    (void)ge::AttrUtils::GetInt(graph, SRC_MERGED_WEIGHT_SIZE, mergedSize);
    (void)ge::AttrUtils::GetInt(graph, SRC_MERGED_WEIGHT_ADDR, mergedBase);

    auto restoreConstWeight = [&mergedSize, &mergedBase](ge::Node& node) {
        ge::OpDesc& constDesc = node.ROLE(NodeSpec).OpDesc();
        if (constDesc.GetType() != hiai::op::Const::TYPE) {
            return hiai::SUCCESS;
        }

        ge::TensorPtr constWeight = nullptr;
        (void)ge::AttrUtils::MutableTensor(constDesc, "value", constWeight);
        HIAI_EXPECT_NOT_NULL(constWeight);
        // Non-empty data means this weight was already split out before this call.
        if (constWeight->GetData().GetSize() != 0) {
            return hiai::SUCCESS;
        }

        int64_t mergedOffset = 0;
        HIAI_EXPECT_TRUE(ge::AttrUtils::GetInt(constWeight->GetTensorDesc(), "merged_offset", mergedOffset));

        const uint32_t byteCount = ge::TensorUtils::GetWeightSize(constWeight->GetTensorDesc());
        const int64_t byteCount64 = static_cast<int64_t>(byteCount);
        // The slice [mergedOffset, mergedOffset + byteCount) must lie inside the merged buffer.
        HIAI_EXPECT_TRUE(mergedOffset >= 0 && mergedOffset <= UINT32_MAX && (byteCount64 <= mergedSize) &&
            (mergedSize - byteCount64 >= mergedOffset));

        // Copy the slice out of the merged buffer into this tensor.
        const uintptr_t sliceAddr = static_cast<uintptr_t>(mergedBase + mergedOffset);
        HIAI_EXPECT_EXEC(constWeight->SetData(reinterpret_cast<uint8_t*>(sliceAddr), byteCount));

        // Drop the merge bookkeeping: the offset attribute and the recorded weight size.
        ge::TensorUtils::DeleteAttr(constWeight->MutableTensorDesc(), "merged_offset");
        ge::TensorUtils::SetWeightSize(constWeight->MutableTensorDesc(), 0);
        return hiai::SUCCESS;
    };

    HIAI_EXPECT_EXEC(graph.ROLE(GraphListWalker).WalkAllNodes(restoreConstWeight));
    (void)graph.DelAttr(WEIGHT_MERGED);
    (void)graph.DelAttr(SRC_MERGED_WEIGHT_SIZE);
    (void)graph.DelAttr(SRC_MERGED_WEIGHT_ADDR);
    return ge::GRAPH_SUCCESS;
}

/**
 * Reports whether the graph that owns `node` carries the "graph_from_liteV100_model" flag.
 * @return The flag's value; false unless GetBool sets it (the lookup's return status is ignored).
 */
bool IsLiteV100Model(const ge::Node& node)
{
    bool fromLiteV100 = false;
    const ge::ComputeGraph& ownerGraph = node.ROLE(NodeSpec).OwnerComputeGraph();
    (void)ge::AttrUtils::GetBool(ownerGraph, "graph_from_liteV100_model", fromLiteV100);
    return fromLiteV100;
}

/**
 * Reports whether the graph owning `node` represents a compiled model.
 * Graphs from Lite V100 models always count as compiled. Otherwise the graph's
 * "Is_Compiled_Model" flag decides; it stays false unless GetBool sets it.
 */
bool IsCompiledModel(const ge::Node& node)
{
    if (IsLiteV100Model(node)) {
        return true;
    }

    const ge::ComputeGraph& ownerGraph = node.ROLE(NodeSpec).OwnerComputeGraph();
    bool compiled = false;
    (void)ge::AttrUtils::GetBool(ownerGraph, "Is_Compiled_Model", compiled);
    return compiled;
}
} // namespace hiai